In [79]:
%load_ext autoreload

# Reload modules imported using %aimport
%autoreload 1

import librosa
import matplotlib.pyplot as plt
import numpy as np
%aimport wavenet_features

# Define the sampling rate in Hz
SR = 16000

# Define the number of classes
C = 256

# Define the receptive field size
RF = 3070

from IPython.display import Audio, HTML, display
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload

Waveforms and spectrograms of unit activations for various prefix frequencies

In [114]:
# Whole-tone scale from A3 (220 Hz) to A4 (440 Hz): split the octave into 12
# equal-tempered semitones (13 log-spaced points, both A's included) and keep
# every second one -> A3, B3, C#4, D#4, F4, G4, A4.
freqs = np.logspace(np.log10(220), np.log10(440), num=13)[0::2]

# Length of each generated example, in milliseconds
duration_ms = 20

def extract_channel_activations(layer_activations, channel: int = 0):
    """Extract the activations associated with one channel of a layer.

    All axes after the first are flattened (row-major), and the column at
    index `channel` of the resulting (n_timesteps, n_features) view is returned.

    Args:
        layer_activations: array-like with a `reshape` method whose first axis
            is treated as time (presumably a numpy array of shape
            (n_timesteps, ...) — TODO confirm against
            wavenet_features.generate_activations).
        channel: index into the flattened non-time axes. Defaults to 0 (the
            first channel), which preserves the original behaviour.

    Returns:
        A 1-D slice of length n_timesteps.
    """
    # Flatten all but the first axis so any layer layout reduces to 2-D
    n_timesteps = len(layer_activations)
    flat_activations = layer_activations.reshape(n_timesteps, -1)

    # The channel choice is arbitrary; the first one is used by default
    return flat_activations[:, channel]

# Model checkpoint used both to generate the examples and to record activations
SNAPSHOT_PATH = 'snapshots/chromatic_2022-06-30_14-45-48'

# NOTE: slow. The recorded run printed seven 320-step progress bars
# (320 = duration_ms * SR / 1000, presumably one bar per frequency), each taking
# ~3-4.5 min, i.e. roughly 25 minutes in total. TODO consider caching the results.
examples, freq_activations = wavenet_features.generate_activations(
    freqs,
    duration_ms,
    sample_rate=SR,
    n_classes=C,
    f=extract_channel_activations,
    snapshot_path=SNAPSHOT_PATH,
)
100%|██████████████████████████████████████████████████████████████████████████████| 320/320 [03:50<00:00,  1.39it/s]
100%|██████████████████████████████████████████████████████████████████████████████| 320/320 [04:33<00:00,  1.17it/s]
100%|██████████████████████████████████████████████████████████████████████████████| 320/320 [03:11<00:00,  1.67it/s]
100%|██████████████████████████████████████████████████████████████████████████████| 320/320 [03:57<00:00,  1.35it/s]
100%|██████████████████████████████████████████████████████████████████████████████| 320/320 [03:11<00:00,  1.67it/s]
100%|██████████████████████████████████████████████████████████████████████████████| 320/320 [03:48<00:00,  1.40it/s]
100%|██████████████████████████████████████████████████████████████████████████████| 320/320 [03:16<00:00,  1.63it/s]
In [ ]:
n_freqs, n_layers, n_timesteps = freq_activations.shape
# Each frequency must have a generated example and a row of activations
assert len(freqs) == n_freqs == len(examples), 'freqs, examples and activations disagree'

# Figure geometry is identical for every layer, so compute it once
subplot_aspect = 0.3
fig_width = 9
fig_height = subplot_aspect * fig_width

# FFT window length (in samples) used for every spectrogram
spec_nfft = 150

for layer_id in range(n_layers):
    # Row 0: waveforms, row 1: spectrograms. Columns alternate between the
    # generated audio and this layer's activations for each frequency.
    fig, axes = plt.subplots(2, 2 * n_freqs, dpi=150, figsize=(fig_width, fig_height))
    fig.suptitle(f'Layer {layer_id + 1}')

    for ax in axes.flat:
        ax.set_xticks([])
        ax.set_yticks([])
    # Ticks are hidden, so label the rows instead
    axes[0, 0].set_ylabel('Waveform', fontsize=6)
    axes[1, 0].set_ylabel('Spectrogram', fontsize=6)

    for freq_id, freq in enumerate(freqs):
        gen_col = 2 * freq_id
        act_col = gen_col + 1

        # Keep only the tail of the generated audio so it spans the same
        # number of timesteps as the recorded activations
        generated = examples[freq_id][-n_timesteps:]
        activations = freq_activations[freq_id, layer_id]

        axes[0, gen_col].set_title('Generated\n(%.1f Hz)' % freq, fontsize=6)
        axes[0, act_col].set_title('Activations\n(%.1f Hz)' % freq, fontsize=6)

        # Plot the generated and activation waveforms
        axes[0, gen_col].plot(generated)
        axes[0, act_col].plot(activations)

        # Plot the generated and activation spectrograms
        axes[1, gen_col].specgram(generated, Fs=SR, NFFT=spec_nfft)
        axes[1, act_col].specgram(activations, Fs=SR, NFFT=spec_nfft)

    fig.tight_layout()
    plt.show()
    # One figure is created per layer; release each explicitly
    plt.close(fig)
In [ ]: